import gin
import torch
from collections import deque
from itertools import islice
import torch
import pandas as pd
import random
import ast
import spacy 
from sentence_transformers import SentenceTransformer
from model import cudaModel

# Load gin bindings before anything gin-configurable is constructed or called
# (compute_weights and perturbate_trajectory below are @gin.configurable).
# NOTE(review): presumably cudaModel also reads gin bindings, which is why the
# config is parsed before it is instantiated -- confirm in model.py.
config = 'cuda.gin'
gin.parse_config_file(config)
# Single shared model instance used by the embedding helpers in this module.
cuda_model = cudaModel()

# Global caches, keyed by document id. Nothing in this module evicts entries,
# so they grow with the number of distinct documents seen.
# embedding_cache: doc_id -> whole-document embedding (compute_embedding).
# sentence_embedding_cache: doc_id -> (sentences, sentence_embeddings)
#   (get_sentence_embeddings).
embedding_cache = {}
sentence_embedding_cache = {}

@gin.configurable
def compute_weights(recentness, decay_rate):
    """Return the exponential recency weight ``exp(-decay_rate * recentness)``.

    Args:
        recentness: How far back an item is (0 = most recent). May be a Python
            number, a sequence, a NumPy array or a tensor; non-scalar inputs
            are processed element-wise.
        decay_rate: Decay coefficient. For a positive value, older items
            (larger ``recentness``) receive smaller weights.

    Returns:
        A float32 tensor with the same shape as ``recentness``.
    """
    # torch.as_tensor reuses an existing float32 tensor instead of copying it.
    # torch.tensor() always copies and warns when given a tensor.
    recentness = torch.as_tensor(recentness, dtype=torch.float32)
    decay_rate = torch.as_tensor(decay_rate, dtype=torch.float32)

    return torch.exp(-decay_rate * recentness)

def compute_rmse(embedding1, embedding2):
    """Return the root-mean-square difference between two tensors.

    The inputs are subtracted element-wise, so they must have equal or
    broadcastable shapes. The mean is taken over every element and the result
    is a 0-d tensor.
    """
    squared_error = (embedding1 - embedding2) ** 2
    mean_squared_error = torch.mean(squared_error)
    return torch.sqrt(mean_squared_error)

def compute_embedding(doc_id, doc_text):
    """Return the whole-document embedding for ``doc_id``, memoized.

    The first request for a ``doc_id`` encodes ``doc_text`` with
    ``cuda_model.encode`` and stores the result in ``embedding_cache``. Later
    requests return the cached value and ignore ``doc_text``.
    """
    if doc_id in embedding_cache:
        return embedding_cache[doc_id]

    fresh_embedding = cuda_model.encode(doc_text)
    embedding_cache[doc_id] = fresh_embedding
    return fresh_embedding

def get_sentence_embeddings(doc_id, doc_text):
    """Return ``(sentences, sentence_embs)`` for a document, memoized by id.

    On a cache miss, the document is split with
    ``cuda_model.extract_sentences``. All sentences are then encoded in a
    single ``cuda_model.encode`` batch call, and the pair is stored in
    ``sentence_embedding_cache[doc_id]``. On a hit, ``doc_text`` is ignored.
    """
    if doc_id not in sentence_embedding_cache:
        sentence_list = cuda_model.extract_sentences(doc_text)
        # One batched encode call rather than one call per sentence.
        encoded = cuda_model.encode(sentence_list)
        sentence_embedding_cache[doc_id] = (sentence_list, encoded)
    return sentence_embedding_cache[doc_id]

def _recency_weighted_scores(sentence_embs, recent_embeddings, decay_rate, device):
    """Score each sentence by its recency-weighted distance to recent history.

    Args:
        sentence_embs: Sentence embeddings indexed as ``[num_sentences, dim]``.
        recent_embeddings: Non-empty list of history embeddings ordered
            most-recent-first; stacked to ``[H, dim]``.
        decay_rate: History item ``i`` is weighted by
            ``compute_weights(i, decay_rate)``.
        device: Device the weight vector is moved to.

    Returns:
        A ``[num_sentences]`` tensor. Each entry is the weighted sum, over
        history items, of the Euclidean distance between that sentence and
        the item. Lower scores mean closer to the recent history.
    """
    history_embs = torch.stack(recent_embeddings)  # [H, dim]
    weights_t = torch.stack(
        [compute_weights(i, decay_rate) for i in range(len(recent_embeddings))]
    ).to(device)  # [H], float32

    diffs = sentence_embs[:, None, :] - history_embs[None, :, :]  # [S, H, dim]
    distances = (diffs ** 2).sum(dim=2).sqrt()                    # [S, H]
    return (distances * weights_t).sum(dim=1)                      # [S]


@gin.configurable
def perturbate_trajectory(row, news_df, summ_df, purturbed_summ_df, curr_summ_id, previous_doc_steps=10, decay_rate=0.3, perturbation_probability=0.8):
    """Replay one trajectory, emitting a (possibly perturbed) summary for
    every ``'gen_summ'`` action.

    The trajectory is walked in order:

    * ``'click'`` actions push the clicked document's embedding onto a
      most-recent-first history, capped at ``previous_doc_steps`` entries.
    * ``'gen_summ'`` actions draw from the global ``random`` module. With
      probability ``perturbation_probability`` the summary becomes the single
      sentence of the current news body with the lowest recency-weighted
      distance to the history (see ``_recency_weighted_scores``). With no
      history, the first sentence is used. Otherwise the existing summary is
      reused: the ``summ_df`` entry for the next trajectory item. When there
      is no next item there is nothing to reuse, so the generated sentence is
      used instead. In every case a row is recorded under ``S-<curr_summ_id>``,
      and the next trajectory item (if any) is renamed to that id.

    Args:
        row: Record with ``'Docs'`` and ``'Action'`` (string literals of
            equal-length lists, parsed with ``ast.literal_eval``) and
            ``'UserID'``.
        news_df: DataFrame indexed by news id with a ``'News body'`` column.
        summ_df: DataFrame indexed by summary id with a ``'Summary'`` column.
        purturbed_summ_df: Accumulator DataFrame. New rows use the columns
            ``SummID``, ``NewsID``, ``UserID`` and ``Summary``.
        curr_summ_id: Next integer to use in ``S-<id>`` summary ids.
        previous_doc_steps: Maximum number of recent clicks considered.
        decay_rate: Recency decay passed to ``compute_weights``.
        perturbation_probability: Probability of using the generated sentence.

    Returns:
        ``(purturbed_summ_df, curr_summ_id, doc_list)``:
        * the accumulator with one new row per ``'gen_summ'`` action;
        * the next unused id;
        * the parsed ``Docs`` list with the renamed entries.

    Raises:
        AssertionError: If ``Docs`` and ``Action`` have different lengths.
        ValueError: If a generated summary is needed but the document yields
            no sentences.
    """
    doc_list = ast.literal_eval(row['Docs'])
    action_list = ast.literal_eval(row['Action'])
    assert len(doc_list) == len(action_list), (
        f"Docs/Action length mismatch: {len(doc_list)} != {len(action_list)}"
    )

    # Embeddings of clicked documents, most recent first.
    history_embeddings_deque = deque()
    # Rows are collected and concatenated once at the end. Calling pd.concat
    # per row copies the accumulator each time, which is quadratic.
    new_rows = []

    # doc_list[doc_idx + 1] may be overwritten below. enumerate() reads the
    # list lazily, so the next iteration sees the renamed entry.
    for doc_idx, doc in enumerate(doc_list):
        action = action_list[doc_idx]

        if action == 'gen_summ':
            query_document = news_df.loc[doc]['News body']
            sentences, sentence_embs = get_sentence_embeddings(doc, query_document)
            has_next = doc_idx < len(doc_list) - 1

            # Decide whether to perturb (1) or reuse the existing summary (0).
            perturb = random.choices(
                [0, 1],
                weights=[1 - perturbation_probability, perturbation_probability],
                k=1,
            )[0] == 1

            if perturb or not has_next:
                # Without a next item there is no existing summary to reuse,
                # so fall back to the generated sentence instead of raising
                # IndexError.
                if len(sentences) == 0:
                    raise ValueError(
                        f"No sentences extracted for document {doc!r}; "
                        "cannot build a generated summary."
                    )
                device = cuda_model.get_device()
                recent_embeddings = list(islice(history_embeddings_deque, previous_doc_steps))
                if recent_embeddings:
                    sentences_scores = _recency_weighted_scores(
                        sentence_embs, recent_embeddings, decay_rate, device
                    )
                else:
                    # No history: every score ties at 0, and argmin returns
                    # the first minimum, i.e. the first sentence.
                    sentences_scores = torch.zeros(len(sentences), device=device)
                new_summary = sentences[int(torch.argmin(sentences_scores))]
            else:
                new_summary = summ_df.loc[doc_list[doc_idx + 1]]['Summary']

            # The id scheme is the same whether or not the summary was perturbed.
            new_summary_id = f"S-{curr_summ_id}"
            new_rows.append([new_summary_id, doc, row['UserID'], new_summary])

            # Point the next trajectory item at the newly recorded summary.
            if has_next:
                doc_list[doc_idx + 1] = new_summary_id

            curr_summ_id += 1

        elif action == 'click':
            doc_embedding = compute_embedding(doc, news_df.loc[doc]['News body'])
            history_embeddings_deque.appendleft(doc_embedding)

            # Keep only the newest `previous_doc_steps` embeddings.
            if len(history_embeddings_deque) > previous_doc_steps:
                history_embeddings_deque.pop()

    if new_rows:
        new_summary_df = pd.DataFrame(new_rows, columns=['SummID', 'NewsID', 'UserID', 'Summary'])
        purturbed_summ_df = pd.concat([purturbed_summ_df, new_summary_df], ignore_index=True)

    return purturbed_summ_df, curr_summ_id, doc_list